#include "xpu_kernels.h"
#include <bit>
#include <cmath>
#include <iostream>

#include <sycl/sycl.hpp>

// Decodes one FP4 code to its normalized value in [-1, 1].
//
// Only the low nibble of `val` is inspected; higher bits are ignored. Codes
// 0x8-0xF are the negated counterparts of 0x0-0x7, except that code 0x8 yields
// +0.0f (not -0.0f). Callers multiply the result by the block's absmax.
inline float dDequantizeFP4(unsigned char val) {
    switch (val & 0x0F) {
    case 0x0: return 0.00000000f;
    case 0x1: return 5.208333333e-03f;
    case 0x2: return 0.66666667f;
    case 0x3: return 1.00000000f;
    case 0x4: return 0.33333333f;
    case 0x5: return 0.50000000f;
    case 0x6: return 0.16666667f;
    case 0x7: return 0.25000000f;
    case 0x8: return 0.00000000f;
    case 0x9: return -5.208333333e-03f;
    case 0xA: return -0.66666667f;
    case 0xB: return -1.00000000f;
    case 0xC: return -0.33333333f;
    case 0xD: return -0.50000000f;
    case 0xE: return -0.16666667f;
    default:  return -0.25000000f; // 0xF: the mask limits the switch to 0..15
    }
}

// Decodes one NF4 (4-bit NormalFloat) code to its normalized value in [-1, 1].
//
// Only the low nibble of `val` is inspected. The table increases
// monotonically with the code: 0x0 -> -1.0, 0x7 -> 0.0, 0xF -> 1.0.
// The constants come from the decision tree generated by
// test_normal_map_tree in tests/test_functional.py.
inline float dDequantizeNF4(unsigned char val) {
    switch (val & 0x0F) {
    case 0x0: return -1.0f;
    case 0x1: return -0.6961928009986877f;
    case 0x2: return -0.5250730514526367f;
    case 0x3: return -0.39491748809814453f;
    case 0x4: return -0.28444138169288635f;
    case 0x5: return -0.18477343022823334f;
    case 0x6: return -0.09105003625154495f;
    case 0x7: return 0.0f;
    case 0x8: return 0.07958029955625534f;
    case 0x9: return 0.16093020141124725f;
    case 0xA: return 0.24611230194568634f;
    case 0xB: return 0.33791524171829224f;
    case 0xC: return 0.44070982933044434f;
    case 0xD: return 0.5626170039176941f;
    case 0xE: return 0.7229568362236023f;
    default:  return 1.0f; // 0xF: the mask limits the switch to 0..15
    }
}

// Dequantizes one TILE_SIZE-byte tile of `A` into `out`.
//
// Each work-item handles NUM_PER_TH consecutive bytes of `A`:
//  - General8bit: each byte indexes `code` and yields one output value;
//  - FP4 / NF4:  each byte packs two 4-bit codes (high nibble first) and
//                yields two output values.
// Every value is scaled by absmax[byte_index / blocksize], i.e. one absmax
// entry per `blocksize` bytes of A.
//
// Assumptions (not checked here; TODO confirm against the launcher):
//  - `blocksize` is a power of two (the absmax lookup uses a shift);
//  - the work-group has TILE_SIZE / NUM_PER_TH work-items;
//  - `A` and `out` are suitably aligned for the sycl::vec fast paths.
template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE>
SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::operator()(sycl::nd_item<1> item) const {
    static_assert(
        DATA_TYPE == General8bit || DATA_TYPE == FP4 || DATA_TYPE == NF4, "kDequantizeBlockwise: unsupported DATA_TYPE"
    );

    // 4-bit types pack two values per input byte; General8bit stores one.
    constexpr int vals_per_byte = (DATA_TYPE > 0) ? 2 : 1;
    constexpr int local_dst_size = NUM_PER_TH * vals_per_byte;

    const int base_idx = item.get_group(0) * TILE_SIZE;
    // Kept signed on purpose: the valid-item counts below can be <= 0 for a
    // tile that lies past the data, and comparing them against a size_t would
    // turn a negative count into a huge bound.
    const int local_idx = static_cast<int>(item.get_local_id(0)) * NUM_PER_TH;

    uint8_t qvals[NUM_PER_TH];
    T vals[local_dst_size];

    // Number of valid input bytes (load) and output values (store) in this tile.
    int local_load_idx;
    int local_store_idx;
    if constexpr (DATA_TYPE > 0) {
        local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx);
        local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2);
    } else {
        local_load_idx = sycl::min(TILE_SIZE, n - base_idx);
        local_store_idx = local_load_idx;
    }

    // Work-items with no valid input never store anything, so they skip the
    // absmax read; their block index could lie past the end of `absmax`.
    float local_abs_max = 0.0f;
    if (local_idx < local_load_idx) {
        // blocksize is a power of two: shift instead of an expensive division.
        local_abs_max = absmax[(base_idx + local_idx) >> (31 - std::countl_zero<unsigned int>(blocksize))];
    }

    if (local_idx + NUM_PER_TH <= local_load_idx) {
        // Entire chunk in range: one vector load. base_idx + local_idx is a
        // multiple of NUM_PER_TH, so the vector index is exact.
        const sycl::vec<uint8_t, NUM_PER_TH> packed =
            reinterpret_cast<const sycl::vec<uint8_t, NUM_PER_TH>*>(A)[(base_idx + local_idx) / NUM_PER_TH];
#pragma unroll
        for (int i = 0; i < NUM_PER_TH; i++)
            qvals[i] = packed[i];
    } else {
        // Tail of the data: load element-wise and zero-fill the rest.
#pragma unroll
        for (int i = 0; i < NUM_PER_TH; i++) {
            if (local_idx + i < local_load_idx)
                qvals[i] = A[base_idx + local_idx + i];
            else
                qvals[i] = static_cast<uint8_t>(0);
        }
    }

    if constexpr (DATA_TYPE == General8bit) {
#pragma unroll
        for (int j = 0; j < NUM_PER_TH; j++)
            vals[j] = code[qvals[j]] * local_abs_max;
    } else {
        // Two codes per byte, high nibble first.
#pragma unroll
        for (int j = 0; j < NUM_PER_TH; j++) {
            const unsigned char hi = qvals[j] >> 4;
            const unsigned char lo = qvals[j] & 0x0F;
            if constexpr (DATA_TYPE == FP4) {
                vals[j * 2] = dDequantizeFP4(hi) * local_abs_max;
                vals[j * 2 + 1] = dDequantizeFP4(lo) * local_abs_max;
            } else {
                vals[j * 2] = dDequantizeNF4(hi) * local_abs_max;
                vals[j * 2 + 1] = dDequantizeNF4(lo) * local_abs_max;
            }
        }
    }

    const int out_base = base_idx * vals_per_byte;
    const int local_dst_idx = local_idx * vals_per_byte;

    if (local_dst_idx + local_dst_size <= local_store_idx) {
        // Entire output chunk in range: gather into a vector and store once.
        sycl::vec<T, local_dst_size> packed_out;
#pragma unroll
        for (int i = 0; i < local_dst_size; i++)
            packed_out[i] = vals[i];
        reinterpret_cast<sycl::vec<T, local_dst_size>*>(out)[(out_base + local_dst_idx) / local_dst_size] = packed_out;
    } else {
#pragma unroll
        for (int i = 0; i < local_dst_size; i++) {
            if (local_dst_idx + i < local_store_idx)
                out[out_base + local_dst_idx + i] = vals[i];
        }
    }
}

// 4-bit GEMV: each sub-group computes one output element
//   out[row_B] = sum_k A[k] * dequant(B[row_B, k]),  k in [0, K)
// where each byte of B packs two 4-bit codes (high nibble first). The codes
// index `quant_map`, which is filled from datatype[0..15], and the results are
// scaled by the absmax of their quantization block.
//
// Each lane processes 32 codes (16 bytes of B) per loop step. The partial dot
// products are accumulated in float and reduced across the sub-group.
//
// Assumptions (not checked here; TODO confirm against the launcher):
//  - ldb == K / 2, so (2 * offset_B + inner_idx) is the element index in B;
//  - `blocksize` is a power of two;
//  - rows of B and A are 16-byte aligned (e.g. K a multiple of 32), because the
//    vector loads truncate their index to a 16-byte boundary;
//  - the work-group holds NUM_PER_THREAD sub-groups of SUBG_SIZE lanes.
template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD, size_t SUBG_SIZE, int BITS>
SYCL_EXTERNAL void
    kgemv_4bit_inference<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS>::operator()(sycl::nd_item<1> item) const {
    // The 16-byte vector loads below are hard-wired to 32 codes per lane and
    // to 8 elements of A per chunk of BITS-wide T.
    static_assert(SUBG_SIZE == 32, "kgemv_4bit_inference: vector loads assume 32-lane sub-groups");
    static_assert(
        static_cast<size_t>(BITS) == 8 * sizeof(T), "kgemv_4bit_inference: BITS must equal the bit width of T"
    );

    size_t idx = item.get_local_id();
    const int sg_idx = idx / SUBG_SIZE;
    const int sg_lane = idx % SUBG_SIZE;
    constexpr int num_values_4bit = SUBG_SIZE;
    const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx;
    const int offset_B = ldb * row_B;
    constexpr int num_values_8bit = num_values_4bit / 2;
    float local_C = 0.0f;

    // Written through sycl::vec<int, 4> below, so they need its alignment.
    alignas(sycl::vec<int, 4>) unsigned char local_B_4bit[num_values_8bit];
    T local_B[num_values_4bit / 4];
    alignas(sycl::vec<int, 4>) T local_A[num_values_4bit / 4];
    T local_absmax = T(0.0f);

    if (idx < 16) {
        quant_map[idx] = T(datatype[idx]);
    }

    sycl::group_barrier(item.get_group());

    for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; inner_idx += SUBG_SIZE * num_values_4bit) {
        const int inner_idx_halved = inner_idx / 2;

        if (row_B < N) {
            // Only rows that exist read absmax: for row_B >= N the index would
            // address entries of a non-existent row of B.
            // blocksize is a power of two: shift instead of an expensive division.
            const int absidx = ((2 * offset_B) + inner_idx) >> (31 - std::countl_zero(static_cast<unsigned int>(blocksize)));
            local_absmax = absmax[absidx];

            if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
                // Fast path: one 16-byte load of this lane's packed codes.
                reinterpret_cast<sycl::vec<int, 4>(&)[num_values_8bit]>(local_B_4bit)[0] =
                    reinterpret_cast<sycl::vec<int, 4>*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
            } else {
#pragma unroll
                for (int j = 0; j < (num_values_8bit); j++)
                    if ((inner_idx_halved) + j < (K / 2))
                        local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
                    else
                        local_B_4bit[j] = 0b01110111; // pad with code 7 in both nibbles
            }
        } else {
            // Padding row: the result is never stored (see the final guard).
#pragma unroll
            for (int j = 0; j < (num_values_8bit); j++)
                local_B_4bit[j] = 0b01110111;
        }

        // Four chunks of 8 codes / 8 elements of A each.
        for (int i = 0; i < 4; i++) {
#pragma unroll
            for (int k = 0; k < num_values_8bit / 4; k++) {
                local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
                local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
            }

            if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
                if constexpr (BITS == 16) {
                    // 8 x 16-bit = one 16-byte vector.
                    reinterpret_cast<sycl::vec<int, 4>(&)[num_values_4bit / 4]>(local_A)[0] =
                        reinterpret_cast<sycl::vec<int, 4>*>(A)[inner_idx / (num_values_4bit / 4) + i];
                } else {
                    // 8 x 32-bit = two 16-byte vectors.
                    reinterpret_cast<sycl::vec<int, 4>(&)[num_values_4bit / 4]>(local_A)[0] =
                        reinterpret_cast<sycl::vec<int, 4>*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
                    reinterpret_cast<sycl::vec<int, 4>(&)[num_values_4bit / 4]>(local_A)[1] =
                        reinterpret_cast<sycl::vec<int, 4>*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
                }
            } else {
                // Tail of K: zero-fill so padded codes contribute nothing.
#pragma unroll
                for (int k = 0; k < num_values_4bit / 4; k++)
                    if (inner_idx + (i * num_values_4bit / 4) + k < K)
                        local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
                    else
                        local_A[k] = T(0.0f);
            }

            // accumulate in float for accuracy;
#pragma unroll
            for (int k = 0; k < num_values_4bit / 4; k++) {
                local_C += (float)(local_A[k] * local_B[k]);
            }
        }
    }

    local_C = sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>());

    if (row_B < N && sg_lane == 0)
        out[row_B] = T(local_C);
}

//==============================================================
//              EXPLICIT TEMPLATE INSTANTIATIONS
//==============================================================

// kDequantizeBlockwise<T, TILE_SIZE=512, NUM_PER_TH=4, DATA_TYPE>:
// every output type (half, float, bfloat16) for each quantization format.

// 8-bit codebook quantization.
template class kDequantizeBlockwise<sycl::half, 512, 4, General8bit>;
template class kDequantizeBlockwise<float, 512, 4, General8bit>;
template class kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 4, General8bit>;

// 4-bit FP4.
template class kDequantizeBlockwise<sycl::half, 512, 4, FP4>;
template class kDequantizeBlockwise<float, 512, 4, FP4>;
template class kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 4, FP4>;

// 4-bit NF4.
template class kDequantizeBlockwise<sycl::half, 512, 4, NF4>;
template class kDequantizeBlockwise<float, 512, 4, NF4>;
template class kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 4, NF4>;

// kgemv_4bit_inference<T, GROUP_SIZE=128, NUM_PER_THREAD=4, SUBG_SIZE=32, BITS>:
// BITS is the bit width of T (16 for half/bfloat16, 32 for float).
template class kgemv_4bit_inference<sycl::half, 128, 4, 32, 16>;
template class kgemv_4bit_inference<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
template class kgemv_4bit_inference<float, 128, 4, 32, 32>;
